package org.example.offline

import breeze.numerics.sqrt
import org.apache.spark.SparkConf
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession
import org.example.java.model.Constant._
import org.example.scala.model.{MongoConfig, MovieRating}

/**
 * Offline job that grid-searches ALS hyper-parameters (rank and lambda) over the
 * ratings collection and prints the combination with the lowest RMSE.
 */
object ALSTrainer {

  /**
   * Entry point: loads the ratings, splits them into a training and a held-out
   * part, runs the parameter search and always stops the Spark session.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {

    val config = Map(
      "spark.cores" -> "local[*]",
      "mongo.uri" -> "mongodb://localhost:27018/recommender",
      "mongo.db" -> "recommender"
    )

    // Spark configuration and session
    val sparkConf = new SparkConf().setAppName("ALSTrainer").setMaster(config("spark.cores"))
    val spark = SparkSession.builder().config(sparkConf).getOrCreate()

    // Stop the session even when loading or training fails.
    try {
      // MongoDB connection settings
      val mongoConf = MongoConfig(config("mongo.uri"), config("mongo.db"))

      import spark.implicits._

      // Load the ratings and convert them to MLlib `Rating`s.
      val ratingRDD = spark
        .read
        .option("uri", mongoConf.uri)
        .option("collection", MONGO_RATING_COLLECTION)
        .format(MONGO_DRIVER_CLASS)
        .load()
        .as[MovieRating]
        .rdd
        .map(rating => Rating(rating.uid, rating.mid, rating.score))
        .cache() // reused by every candidate model below

      // Hold out 20% of the ratings so that candidates are compared on data
      // they were not fitted to.
      val Array(trainingRDD, testRDD) = ratingRDD.randomSplit(Array(0.8, 0.2))

      // Print the best parameter combination.
      adjustALSParams(trainingRDD, testRDD)
    } finally {
      spark.stop()
    }
  }

  /**
   * Searches the parameter grid, scoring every model on the same data it was
   * trained on. Kept for callers that have no held-out set; the resulting RMSE
   * is optimistic, so prefer the two-argument overload.
   *
   * @param trainData ratings used for both fitting and evaluation
   */
  def adjustALSParams(trainData: RDD[Rating]): Unit =
    adjustALSParams(trainData, trainData)

  /**
   * Trains one ALS model per (rank, lambda) combination on `trainData`, scores
   * it on `testData` and prints the `(rank, lambda, rmse)` triple with the
   * smallest RMSE.
   *
   * @param trainData ratings the models are fitted to
   * @param testData  ratings the models are evaluated on
   */
  def adjustALSParams(trainData: RDD[Rating], testData: RDD[Rating]): Unit = {
    val iterations = 5

    val result = for (rank <- Array(30, 40, 50, 60, 70); lambda <- Array(1, 0.1, 0.001))
      yield {
        val model = ALS.train(trainData, rank, iterations, lambda)
        // Root-mean-square error of this candidate on the evaluation set
        val rmse = getRmse(model, testData)
        (rank, lambda, rmse)
      }

    // Only the minimum is needed, so avoid sorting the whole grid.
    println(result.minBy(_._3))
  }

  /**
   * Computes the root-mean-square error between the ratings in `trainData` and
   * the model's predictions for the same (user, product) pairs.
   *
   * Pairs for which the model returns no prediction drop out of the inner join
   * and do not contribute to the error.
   *
   * @param model     trained matrix-factorization model
   * @param trainData labelled ratings to evaluate against (despite the name, any
   *                  rating set may be passed, e.g. a held-out one)
   * @return the RMSE over all pairs that have both an actual and a predicted rating
   */
  def getRmse(model: MatrixFactorizationModel, trainData: RDD[Rating]): Double = {
    // `predict` expects an RDD[(Int, Int)] of (user, product) pairs.
    val userMovies = trainData.map(item => (item.user, item.product))
    val predictRating = model.predict(userMovies)

    // Key both sides by (user, product) so they can be joined.
    val actual = trainData.map(item => ((item.user, item.product), item.rating))
    val predicted = predictRating.map(item => ((item.user, item.product), item.rating))

    // RMSE = sqrt(mean(squared error))
    sqrt(
      actual.join(predicted).map { case (_, (actualRating, predictedRating)) =>
        val err = actualRating - predictedRating
        err * err
      }.mean()
    )
  }

}
